
from nlu_model.cls.model_pytorch.model_train import TrainModelPipeline
from nlu_model.cls.news_title.util import data_loader, pretrain_w2v, word2idx, TITLE_DICT

if __name__ == "__main__":
    # Train a TextCNN news-title classifier on pretrained word2vec embeddings,
    # then evaluate it on the held-out test split.

    # Maximum token length passed to word2idx for both splits; keep train and
    # test identical so the model sees same-length inputs.
    # (presumably pads/truncates each title to this many ids -- TODO confirm
    # against word2idx)
    max_seq_len = 20

    x_train, x_test, y_train, y_test = data_loader()

    # Word vectors are fitted on the training titles only, so the vocabulary
    # carries no information from the test split. The second return value is
    # not needed by this script.
    embedding_weights, _w2v_model, word2idx_dic = pretrain_w2v(x_train)

    x_train = word2idx(x_train, word2idx_dic, padding=max_seq_len)
    x_test = word2idx(x_test, word2idx_dic, padding=max_seq_len)

    textCNN_param = {
        'vocab_size': len(word2idx_dic),
        # NOTE(review): presumably must equal the dimensionality of the
        # vectors in embedding_weights -- confirm with pretrain_w2v.
        'embed_dim': 300,
        'class_num': len(TITLE_DICT),
        "kernel_num": 16,
        "kernel_size": [1, 2, 3, 4, 5],
        "dropout": 0.5,
        "pre_word_embeds": embedding_weights
    }

    train_config = {
        "MODEL_CONF": textCNN_param,
        "batch_size": 64,
        "epoch": 1
    }
    print(TITLE_DICT)
    train_model_pipeline = TrainModelPipeline(train_config)
    train_model_pipeline.call_train(x_train, y_train)
    # Evaluate on the held-out split. The original evaluated on the training
    # data, which only measures fit and never used the prepared test set.
    train_model_pipeline.call_evaluate(x_test, y_test)